In [1]:
# Core Python & Utilities
import random
import re
from collections import defaultdict
from itertools import product
from tqdm import tqdm

# Data Handling
import numpy as np
import pandas as pd

# Stats & Modeling
from scipy.optimize import curve_fit
from scipy.stats import sem, t, ttest_ind
from sklearn.cluster import KMeans
import statsmodels.api as sm
from statsmodels.formula.api import ols
from statsmodels.stats.multicomp import pairwise_tukeyhsd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import shap
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

# Data Viz
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import seaborn as sns

# Preferences
pd.set_option("display.max_columns", None)

Basic Demo of How it Works¶

In [2]:
# ------------------------------------------------------------------------------------
# MPR Equation 1: Arousal
# A_t = a * r   (original form)
# Recursive update (running average over time):
# A_t = α * r_t + (1 - α) * A_{t-1}
# where:
#   A_t is arousal at time t
#   a is specific activation (strength of reinforcer)
#   r_t is rate of reinforcement at time t (either 0 or 1)
#   α is the learning rate / smoothing factor (0 < α < 1)
# ------------------------------------------------------------------------------------
def update_arousal(prev_arousal: float, reinforcement: int, alpha: float, a: float) -> float:
    """Update arousal based on previous value and new reinforcement event"""
    return alpha * (a * reinforcement) + (1 - alpha) * prev_arousal

# ------------------------------------------------------------------------------------
# MPR Equation 2b: Response Constraint
# b_total = (1 / δ) * (A / (1 + A))
# where:
#   b_total is the total rate of behavior (target + competing responses)
#   δ is the duration of a response
#   A is current arousal
# ------------------------------------------------------------------------------------
def compute_total_response_rate(A: float, delta: float) -> float:
    """Compute total response rate given arousal and motor constraint"""
    return (1 / delta) * (A / (1 + A))


# ------------------------------------------------------------------------------------
# MPR Equation 7: Direction/Coupling
# b_target = C * (1 / δ) * (A / (1 + A))
# where:
#   b_target is the rate of the target response (e.g., lever press)
#   C is the coupling coefficient, representing directional control (0 to 1)
# ------------------------------------------------------------------------------------
def compute_target_response_rate(A: float, delta: float, coupling: float) -> float:
    """Compute target response rate based on arousal and coupling"""
    return coupling * (1 / delta) * (A / (1 + A))


# ------------------------------------------------------------------------------------
# Coupling update (based on blocking/proximity principles):
# C_t = C_{t-1} + η * (proximity - C_{t-1})
# where:
#   η is the learning rate
#   proximity = 1 if reinforced, 0.1 if not (we assume low proximity when no reinforcement)
# ------------------------------------------------------------------------------------
def update_coupling(prev_coupling: float, was_reinforced: bool, learning_rate: float = 0.05, proximity: float = 1.0) -> float:
    """Update coupling coefficient based on reinforcement outcome"""
    target = proximity if was_reinforced else 0.1
    return prev_coupling + learning_rate * (target - prev_coupling)


# ------------------------------------------------------------------------------------
# Wrapper for a single update step
# Inputs:
#   prev_arousal: current arousal level
#   prev_coupling: current coupling strength
#   reinforcement: 1 or 0
#   a: specific activation (magnitude of reinforcer)
#   delta: duration of a response
#   alpha: learning rate for arousal
#   eta: learning rate for coupling
# Outputs:
#   updated A, C, total response rate, target response rate
# ------------------------------------------------------------------------------------
def update_step(prev_arousal, prev_coupling, reinforcement, a, delta, alpha=0.1, eta=0.05):
    """Run one update of arousal, coupling, and response rates"""
    A = update_arousal(prev_arousal, reinforcement, alpha, a)  # now includes a
    C = update_coupling(prev_coupling, was_reinforced=bool(reinforcement), learning_rate=eta)
    b_total = compute_total_response_rate(A, delta)
    b_target = compute_target_response_rate(A, delta, C)
    return A, C, b_total, b_target
In [3]:
# Parameters
num_steps = 1000
delta = 0.25         # response duration (sec)
a = 0.6              # specific activation (motivational level)
alpha = 0.1          # learning rate for arousal
eta = 0.05           # learning rate for coupling
vr_n = 10             # VR-5 schedule: 1 in 5 chance of reinforcement per step

# Initialize state
A = 0.6              # start with moderate arousal
C = 0.5              # start with medium coupling
arousal_history = []
coupling_history = []
b_total_history = []
b_target_history = []
reinforcement_history = []

# Simulation loop
for t in range(num_steps):
    # simulate probabilistic reinforcement (VR schedule)
    reinforcement = np.random.rand() < (1 / vr_n)

    # update arousal, coupling, and response rates
    A, C, b_total, b_target = update_step(A, C, reinforcement, a, delta, alpha, eta)

    # record
    arousal_history.append(A)
    coupling_history.append(C)
    b_total_history.append(b_total)
    b_target_history.append(b_target)
    reinforcement_history.append(reinforcement)

# Plotting
fig, axs = plt.subplots(4, 1, figsize=(12, 10), sharex=True)

# Get times where reinforcement occurred
reinforced_times = [i for i, r in enumerate(reinforcement_history) if r == 1]

# Function to add vertical lines at correct time steps
def add_reinforcement_lines(ax):
    for rt in reinforced_times:
        ax.axvline(x=rt - 0.5, color='red', alpha=0.3, linewidth=0.8)  # Shift back by 0.5 to align with event

# Arousal
axs[0].plot(arousal_history, color='k', label="Arousal (A)")
add_reinforcement_lines(axs[0])
axs[0].set_ylabel("Arousal", fontsize=16)
axs[0].legend(fontsize=14)

# Coupling
axs[1].plot(coupling_history, color='k', label="Coupling (C)")
add_reinforcement_lines(axs[1])
axs[1].set_ylabel("Coupling", fontsize=16)
axs[1].legend(fontsize=14)

# Total response rate
axs[2].plot(b_total_history, color='k', label="Total Response Rate")
add_reinforcement_lines(axs[2])
axs[2].set_ylabel("Total Rate (resp/sec)", fontsize=16)
axs[2].legend(fontsize=14)

# Target response rate
axs[3].plot(b_target_history, color='k', label="Target Response Rate")
add_reinforcement_lines(axs[3])
axs[3].set_ylabel("Target Rate", fontsize=16)
axs[3].legend(fontsize=14)

axs[3].set_xlabel("Time Step", fontsize=20)

plt.tight_layout()
plt.show()
No description has been provided for this image

A Range of Single VR Schedule¶

In [4]:
import numpy as np
import pandas as pd
from itertools import product
from tqdm import tqdm
from scipy.stats import sem, t

# --- MPR Model Functions ---
def update_arousal(prev_arousal: float, reinforcement: int, alpha: float, a: float) -> float:
    return alpha * (a * reinforcement) + (1 - alpha) * prev_arousal

def compute_total_response_rate(A: float, delta: float) -> float:
    return (1 / delta) * (A / (1 + A))

def compute_target_response_rate(A: float, delta: float, coupling: float) -> float:
    return coupling * (1 / delta) * (A / (1 + A))

def update_coupling(prev_coupling: float, was_reinforced: bool, learning_rate: float, proximity: float = 1.0) -> float:
    target = proximity if was_reinforced else 0.1
    return prev_coupling + learning_rate * (target - prev_coupling)

def update_step(prev_arousal, prev_coupling, reinforcement, a, delta, alpha, eta):
    A = update_arousal(prev_arousal, reinforcement, alpha, a)
    C = update_coupling(prev_coupling, was_reinforced=bool(reinforcement), learning_rate=eta)
    b_total = compute_total_response_rate(A, delta)
    b_target = compute_target_response_rate(A, delta, C)
    return A, C, b_total, b_target

# --- Confidence Interval ---
def ci95(data):
    n = len(data)
    return sem(data) * t.ppf(0.975, n - 1)

# --- Main Simulation Function ---
def simulate_mpr_sequence(vr_schedule_list, a, delta, alpha, eta, ao_id, steps_per_schedule=1000):
    A, C = 0.6, 0.5  # Initial internal state
    all_results = []

    for vr_idx, vr in enumerate(vr_schedule_list):
        A_hist, C_hist, bt_hist, bp_hist = [], [], [], []

        for _ in range(steps_per_schedule):
            reinforcement = np.random.rand() < (1 / vr)
            A, C, b_total, b_target = update_step(A, C, reinforcement, a, delta, alpha, eta)
            A_hist.append(A)
            C_hist.append(C)
            bt_hist.append(b_total)
            bp_hist.append(b_target)

        all_results.append({
            'AO_id': ao_id,
            'vr_sequence_index': vr_idx,
            'VR': vr,
            'arousal_avg': np.mean(A_hist),
            'arousal_min': np.min(A_hist),
            'arousal_max': np.max(A_hist),
            'arousal_ci': ci95(A_hist),
            'coupling_avg': np.mean(C_hist),
            'coupling_min': np.min(C_hist),
            'coupling_max': np.max(C_hist),
            'coupling_ci': ci95(C_hist),
            'b_total_avg': np.mean(bt_hist),
            'b_total_min': np.min(bt_hist),
            'b_total_max': np.max(bt_hist),
            'b_total_ci': ci95(bt_hist),
            'b_target_avg': np.mean(bp_hist),
            'b_target_min': np.min(bp_hist),
            'b_target_max': np.max(bp_hist),
            'b_target_ci': ci95(bp_hist),
            'activation': a,
            'delta': delta,
            'alpha': alpha,
            'eta': eta,
        })

    return all_results

# --- Parameter Grid ---
deltas = [0.25, 0.50, 1.0, 2.0]
activations = [0.1, 0.2, 0.4, 0.8, 1.6]
alphas = [0.01, 0.03, 0.1, 0.3, 1.0]
etas = [0.01, 0.03, 0.3, 1.0]
vr_schedules = [1, 3, 10, 30, 100, 300, 1000]

# --- Grid Search Over AO Parameter Combinations ---
results = []
grid = list(product(activations, deltas, alphas, etas))

for ao_id, (a, delta, alpha, eta) in enumerate(tqdm(grid, desc="Running simulations per AO")):
    shuffled_vrs = np.random.permutation(vr_schedules)
    sim_result_list = simulate_mpr_sequence(
        vr_schedule_list=shuffled_vrs,
        a=a,
        delta=delta,
        alpha=alpha,
        eta=eta,
        ao_id=ao_id
    )
    results.extend(sim_result_list)

# --- Final Results DataFrame ---
results_df = pd.DataFrame(results)

# Sample Output
results_df
Running simulations per AO: 100%|██████████| 400/400 [00:44<00:00,  9.08it/s]
Out[4]:
AO_id vr_sequence_index VR arousal_avg arousal_min arousal_max arousal_ci coupling_avg coupling_min coupling_max coupling_ci b_total_avg b_total_min b_total_max b_total_ci b_target_avg b_target_min b_target_max b_target_ci activation delta alpha eta
0 0 0 1 0.149498 0.100022 0.595000 6.165828e-03 0.950502 0.505000 0.999978 0.006166 0.498632 0.363708 1.492163 1.557839e-02 0.449135 0.363700 0.806896 8.456957e-03 0.1 0.25 0.01 0.01
1 0 1 300 0.010124 0.000008 0.099021 1.228692e-03 0.191092 0.100075 0.990979 0.011056 0.038640 0.000033 0.360398 4.592922e-03 0.020553 0.000003 0.357147 3.499529e-03 0.1 0.25 0.01 0.01
2 0 2 100 0.000697 0.000019 0.002307 3.128447e-05 0.106276 0.100172 0.120761 0.000282 0.002786 0.000077 0.009206 1.249306e-04 0.000305 0.000008 0.001112 1.437479e-05 0.1 0.25 0.01 0.01
3 0 3 3 0.033054 0.000811 0.042457 4.928778e-04 0.397488 0.107300 0.482112 0.004436 0.127754 0.003242 0.162911 1.875522e-03 0.052939 0.000348 0.078541 1.065373e-03 0.1 0.25 0.01 0.01
4 0 4 1000 0.003726 0.000009 0.035269 4.465624e-04 0.133534 0.100078 0.417419 0.004019 0.014648 0.000035 0.136269 1.741781e-03 0.003772 0.000003 0.056881 5.896975e-04 0.1 0.25 0.01 0.01
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
2795 399 2 1 1.600000 1.600000 1.600000 1.378580e-17 1.000000 1.000000 1.000000 0.000000 0.307692 0.307692 0.307692 6.892901e-18 0.307692 0.307692 0.307692 6.892901e-18 1.6 2.00 1.00 1.00
2796 399 3 300 0.006400 0.000000 1.600000 6.270057e-03 0.103600 0.100000 1.000000 0.003527 0.001231 0.000000 0.307692 1.205780e-03 0.001231 0.000000 0.307692 1.205780e-03 1.6 2.00 1.00 1.00
2797 399 4 1000 0.001600 0.000000 1.600000 3.139746e-03 0.100900 0.100000 1.000000 0.001766 0.000308 0.000000 0.307692 6.037974e-04 0.000308 0.000000 0.307692 6.037974e-04 1.6 2.00 1.00 1.00
2798 399 5 3 0.542400 0.000000 1.600000 4.702322e-02 0.405100 0.100000 1.000000 0.026451 0.104308 0.000000 0.307692 9.042928e-03 0.104308 0.000000 0.307692 9.042928e-03 1.6 2.00 1.00 1.00
2799 399 6 30 0.056000 0.000000 1.600000 1.825616e-02 0.131500 0.100000 1.000000 0.010269 0.010769 0.000000 0.307692 3.510801e-03 0.010769 0.000000 0.307692 3.510801e-03 1.6 2.00 1.00 1.00

2800 rows × 23 columns

In [5]:
# Sort the data so each AO's line is plotted in VR sequence order
results_df_sorted = results_df.sort_values(['AO_id', 'VR'])
results_df_sorted['b_target_avg'] = results_df_sorted['b_target_avg'].clip(lower=0.0000001)
results_df_sorted['b_target_avg'] = results_df_sorted['b_target_avg']*60

# Set up the figure
fig, ax = plt.subplots(figsize=(8, 6))

# Plot each AO's line
for _, ao_data in results_df_sorted.groupby("AO_id"):
    ax.plot(
        ao_data['VR'],
        ao_data['b_target_avg'],
        color='gray',
        alpha=0.2,
        linewidth=1
    )

# Optional: Overlay a central tendency (e.g., median or mean line)
avg_data = results_df_sorted.groupby('VR')['b_target_avg'].mean().reset_index()
sns.lineplot(
    data=avg_data,
    x='VR', y='b_target_avg',
    color='black',
    linewidth=2,
    label='Mean across AOs'
)

# Log scales for both axes
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_ylim(0.000001, 300)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.5f'))
ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f'))

# Labels and style
ax.set_ylabel("Target Response Rate", fontsize=20)
ax.set_xlabel("VR Schedule", fontsize=20)
ax.legend().remove()
ax.grid(False)
sns.despine()
plt.tight_layout()
plt.show()
No description has been provided for this image
In [6]:
# Compute reinforcement rate
results_df['r'] = 1 / results_df['VR']
results_df['b_target_avg'] = results_df['b_target_avg']*60

# Define generalized hyperbola
def generalized_hyperbola(r, a, b, c):
    return (b * (r ** a)) / (r ** a + c)

# Fit function per AO
def fit_ao_hyperbola(df):
    try:
        popt, _ = curve_fit(
            generalized_hyperbola,
            df['r'],
            df['b_target_avg'],
            p0=[1.0, 1.0, 0.1],
            bounds=(0, np.inf),
            maxfev=10000
        )
        return popt
    except:
        return [np.nan, np.nan, np.nan]

# Fit and store per-AO curves
fit_results = []
ao_curves = []

# Loop over AOs
for ao_id, group in results_df.groupby("AO_id"):
    try:
        # Fit hyperbola
        popt, _ = curve_fit(
            generalized_hyperbola,
            group['r'],
            group['b_target_avg'],
            p0=[1.0, 1.0, 0.1],
            bounds=(0, np.inf),
            maxfev=10000
        )

        # Predict using the fit
        predicted = generalized_hyperbola(group['r'], *popt)

        # Compute R²
        residuals = group['b_target_avg'] - predicted
        ss_res = np.sum(residuals ** 2)
        ss_tot = np.sum((group['b_target_avg'] - np.mean(group['b_target_avg'])) ** 2)
        r_squared = 1 - (ss_res / ss_tot)

        # Store parameters and R²
        fit_results.append({
            'AO_id': ao_id,
            'a': popt[0],
            'b': popt[1],
            'c': popt[2],
            'r_squared': r_squared
        })

        # Store predicted curve
        r_vals = np.linspace(group['r'].min(), group['r'].max(), 200)
        predicted_curve = generalized_hyperbola(r_vals, *popt)
        ao_curves.append(pd.DataFrame({
            'AO_id': ao_id,
            'r': r_vals,
            'predicted_b_target': predicted_curve
        }))

    except:
        continue  # skip AOs that fail to fit

fit_df = pd.DataFrame(fit_results)
curves_df = pd.concat(ao_curves, ignore_index=True)

# Group-level mean and CI
grouped = results_df.groupby('r').agg(
    b_target_avg_mean=('b_target_avg', 'mean'),
    b_target_avg_se=('b_target_avg', sem)
).reset_index()
grouped['b_target_avg_ci'] = grouped['b_target_avg_se'] * t.ppf(0.975, df=len(results_df) - 1)

# Fit group-level curve
popt, _ = curve_fit(
    generalized_hyperbola,
    grouped['r'],
    grouped['b_target_avg_mean'],
    p0=[1.0, 1.0, 0.1],
    bounds=(0, np.inf)
)
r_vals = np.linspace(grouped['r'].min(), grouped['r'].max(), 200)
group_fit = generalized_hyperbola(r_vals, *popt)

# R² calculation for group-level fit
residuals = grouped['b_target_avg_mean'] - generalized_hyperbola(grouped['r'], *popt)
ss_res = np.sum(residuals ** 2)
ss_tot = np.sum((grouped['b_target_avg_mean'] - np.mean(grouped['b_target_avg_mean'])) ** 2)
r_squared = 1 - (ss_res / ss_tot)

# Plot
plt.figure(figsize=(8, 6))

# Plot all individual AO fits
for ao_id, ao_curve in curves_df.groupby("AO_id"):
    plt.plot(ao_curve['r'], ao_curve['predicted_b_target'], alpha=0.1, color='gray')

# Plot group-level fit
plt.plot(r_vals, group_fit, color='black', linewidth=2, label='Group Mean Fit')

# Overlay mean ± CI as white points with black outlines
plt.errorbar(
    grouped['r'],
    grouped['b_target_avg_mean'],
    yerr=grouped['b_target_avg_ci'],
    fmt='o',
    color='white',
    markersize=5,
    elinewidth=1.5,
    ecolor='black',
    capsize=4,
    markeredgecolor='black',
    label='Mean ± 95% CI'
)

# Add annotation for group-level fit parameters
plt.text(
    x=1,
    y=0.004,
    s=f"$a$ = {popt[0]:.2f}\n$b$ = {popt[1]:.2f}\n$c$ = {popt[2]:.2f}\n$r^2$ = {r_squared:.4f}",
    fontsize=20,
    ha='right'
)

# Style
plt.yscale("log")
plt.xlabel("Reinforcement Rate", fontsize=30)
plt.ylabel("Target Response Rate", fontsize=30)
plt.gca().yaxis.set_major_formatter(ticker.FormatStrFormatter('%.4f'))
plt.legend().remove()
plt.ylim(.0001, 300)
plt.grid(False)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
sns.despine()
plt.tight_layout()
plt.show()
No description has been provided for this image
In [8]:
fig, ax = plt.subplots(figsize=(1, 3))
sns.stripplot(fit_df['r_squared'], color='k', alpha=0.2)
sns.despine(top=True, right=True)
plt.ylabel(r"$r^2$", fontsize=14, labelpad=12)
plt.show()
No description has been provided for this image
In [9]:
for vr_val in results_df['VR'].unique():
  df_vr = results_df[results_df['VR'] == vr_val]

  # Metric column mapping
  metrics = {
      'arousal_avg': ('Arousal', 'arousal_ci'),
      'coupling_avg': ('Coupling', 'coupling_ci'),
      'b_target_avg': ('Target Response Rate', 'b_target_ci'),
      'b_total_avg': ('Total Response Rate', 'b_total_ci')
  }

  params = ['activation', 'delta', 'alpha', 'eta']

  # Create 4x4 subplot grid
  fig, axs = plt.subplots(4, 4, figsize=(14, 14), sharex='col', sharey=False)
  fig.subplots_adjust(hspace=0.4, wspace=0.4)

  # Plot each metric × parameter
  for row_idx, (metric_col, (metric_label, ci_col)) in enumerate(metrics.items()):
      for col_idx, param in enumerate(params):
          ax = axs[row_idx, col_idx]

          # Get data for this combo
          sub_df = df_vr[[param, metric_col, ci_col]].copy()

          # Convert the x (param) to categorical codes for spacing
          sub_df['x_val'] = pd.Categorical(sub_df[param]).codes
          x_jittered = sub_df['x_val'] + np.random.uniform(-0.2, 0.2, size=len(sub_df))

          # Plot strip manually
          ax.scatter(x_jittered, sub_df[metric_col], color='black', alpha=0.7, s=15, zorder=2)

          # Plot error bars (CIs)
          ax.errorbar(
              x_jittered,
              sub_df[metric_col],
              yerr=sub_df[ci_col],
              fmt='none',
              ecolor='gray',
              elinewidth=0.8,
              capsize=2,
              alpha=0.6,
              zorder=1
          )

          # Add symbolic labels
          param_symbols = {
              'activation': 'Activation',
              'delta': r'$\delta$',
              'alpha': r'$\alpha$',
              'eta': r'$\eta$'
          }

          # Format axes
          if row_idx == 3:
              ax.set_xlabel(param_symbols[param], fontsize=20, labelpad=8, color='k')
              ax.set_xticks(range(len(sub_df[param].unique())))
              ax.set_xticklabels(sorted(sub_df[param].unique()), fontsize=10, color='k')
          else:
              ax.set_xlabel("")
              ax.set_xticklabels([])
              ax.tick_params(labelbottom=False)

          if col_idx == 0:
              ax.set_ylabel(metric_label, fontsize=20, labelpad=8, color='k')
          else:
              ax.set_ylabel("")
              ax.set_yticklabels([])

          ax.grid(False)
          sns.despine(top=True, right=True)

  plt.tight_layout(rect=[0, 0, 1, 0.97])
  plt.suptitle(f"VR Schedule = {vr_val}", fontsize=20)
  plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

A Range of Single VI Schedules¶

In [10]:
# --- MPR Model Equations ---
def update_arousal(prev_arousal: float, reinforcement: int, alpha: float, a: float) -> float:
    return alpha * (a * reinforcement) + (1 - alpha) * prev_arousal

def compute_total_response_rate(A: float, delta: float) -> float:
    return (1 / delta) * (A / (1 + A))

def compute_target_response_rate(A: float, delta: float, coupling: float) -> float:
    return coupling * (1 / delta) * (A / (1 + A))

def update_coupling(prev_coupling: float, was_reinforced: bool, learning_rate: float, proximity: float = 1.0) -> float:
    target = proximity if was_reinforced else 0.1
    return prev_coupling + learning_rate * (target - prev_coupling)

def update_step(prev_arousal, prev_coupling, reinforcement, a, delta, alpha, eta):
    A = update_arousal(prev_arousal, reinforcement, alpha, a)
    C = update_coupling(prev_coupling, was_reinforced=bool(reinforcement), learning_rate=eta)
    b_total = compute_total_response_rate(A, delta)
    b_target = compute_target_response_rate(A, delta, C)
    return A, C, b_total, b_target

# --- CI Helper ---
def ci95(data):
    n = len(data)
    return sem(data) * t.ppf(0.975, n - 1)

# --- Updated AO Simulation for Sequential VI ---
def simulate_mpr_vi_sequence(vi_schedule_list, a, delta, alpha, eta, ao_id, steps_per_schedule=1000, avg_peck_rate=2.5):
    A, C = 0.6, 0.5
    all_results = []

    for vi_idx, vi_sec in enumerate(vi_schedule_list):
        A_hist, C_hist, bt_hist, bp_hist = [], [], [], []

        avg_irt = 1.0 / avg_peck_rate
        vi_times = np.cumsum(np.random.exponential(scale=vi_sec, size=steps_per_schedule))
        current_time = 0
        next_reinforcement_time = vi_times[0]
        vi_index = 0

        for _ in range(steps_per_schedule):
            current_time += avg_irt
            reinforcement = 0
            if current_time >= next_reinforcement_time:
                reinforcement = 1
                vi_index += 1
                if vi_index < len(vi_times):
                    next_reinforcement_time = vi_times[vi_index]

            A, C, b_total, b_target = update_step(A, C, reinforcement, a, delta, alpha, eta)
            A_hist.append(A)
            C_hist.append(C)
            bt_hist.append(b_total)
            bp_hist.append(b_target)

        all_results.append({
            'AO_id': ao_id,
            'vi_sequence_index': vi_idx,
            'VI': vi_sec,
            'arousal_avg': np.mean(A_hist),
            'arousal_min': np.min(A_hist),
            'arousal_max': np.max(A_hist),
            'arousal_ci': ci95(A_hist),
            'coupling_avg': np.mean(C_hist),
            'coupling_min': np.min(C_hist),
            'coupling_max': np.max(C_hist),
            'coupling_ci': ci95(C_hist),
            'b_total_avg': np.mean(bt_hist),
            'b_total_min': np.min(bt_hist),
            'b_total_max': np.max(bt_hist),
            'b_total_ci': ci95(bt_hist),
            'b_target_avg': np.mean(bp_hist),
            'b_target_min': np.min(bp_hist),
            'b_target_max': np.max(bp_hist),
            'b_target_ci': ci95(bp_hist),
            'activation': a,
            'delta': delta,
            'alpha': alpha,
            'eta': eta,
        })

    return all_results


# Parameters
deltas = [0.25, 0.50, 1.0, 2.0]
activations = [0.1, 0.2, 0.4, 0.8, 1.6]
alphas = [0.01, 0.03, 0.1, 0.3, 1.0]
etas = [0.01, 0.03, 0.3, 1.0]
vi_schedules = [1, 3, 10, 30, 100, 300, 1000]

# Run each AO across all VI schedules in random order
results_vi = []
grid = list(product(activations, deltas, alphas, etas))

for ao_id, (a, delta, alpha, eta) in enumerate(tqdm(grid, desc="Running AO simulations through VI schedules")):
    shuffled_vis = np.random.permutation(vi_schedules)
    sim_result_list = simulate_mpr_vi_sequence(shuffled_vis, a, delta, alpha, eta, ao_id=ao_id)
    results_vi.extend(sim_result_list)

# Final DataFrame
results_vi_df = pd.DataFrame(results_vi)

# Convert responses per sec to response per min
results_vi_df['b_target_avg'] = results_vi_df['b_target_avg']*60
Running AO simulations through VI schedules: 100%|██████████| 400/400 [00:16<00:00, 24.26it/s]
In [11]:
fig, ax = plt.subplots(figsize=(5, 5))
sns.stripplot(
    x='VI',
    y='b_target_avg',
    data=results_vi_df,
    color='k',
    alpha=0.2
)
plt.yscale("log")
plt.gca().yaxis.set_major_formatter(ticker.FormatStrFormatter('%.8f'))
plt.ylim(0.000001, 100)
ax.grid(False)
sns.despine(top=True, right=True)
plt.ylabel("Response Rate", fontsize=18)
plt.xlabel("VI Schedule", labelpad=12, fontsize=18)
plt.show()
No description has been provided for this image
In [12]:
# Sort the data so each AO's line is plotted in VR sequence order
results_df_sorted = results_vi_df.sort_values(['AO_id', 'VI'])
results_df_sorted['b_target_avg'] = results_df_sorted['b_target_avg'].clip(lower=0.00001)

# Set up the figure
fig, ax = plt.subplots(figsize=(8, 6))

# Plot each AO's line
for _, ao_data in results_df_sorted.groupby("AO_id"):
    ax.plot(
        ao_data['VI'],
        ao_data['b_target_avg'],
        color='gray',
        alpha=0.2,
        linewidth=1
    )

# Optional: Overlay a central tendency (e.g., median or mean line)
avg_data = results_df_sorted.groupby('VI')['b_target_avg'].mean().reset_index()
sns.lineplot(
    data=avg_data,
    x='VI', y='b_target_avg',
    color='black',
    linewidth=2,
    label='Mean across AOs'
)

# Log scales for both axes
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_ylim(0.00001, 300)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.5f'))
ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f'))

# Labels and style
ax.set_ylabel("Target Response Rate", fontsize=20)
ax.set_xlabel("VI Schedule", fontsize=20)
ax.legend().remove()
ax.grid(False)
sns.despine()
plt.tight_layout()
plt.show()
No description has been provided for this image
In [13]:
# Compute reinforcement rate
results_vi_df['r'] = 1 / results_vi_df['VI']

# Define generalized hyperbola
def generalized_hyperbola(r, a, b, c):
    return (b * (r ** a)) / (r ** a + c)

# Fit function per AO
def fit_ao_hyperbola(df):
    try:
        popt, _ = curve_fit(
            generalized_hyperbola,
            df['r'],
            df['b_target_avg'],
            p0=[1.0, 1.0, 0.1],
            bounds=(0, np.inf),
            maxfev=10000
        )
        return popt
    except:
        return [np.nan, np.nan, np.nan]

# Fit and store per-AO curves
fit_results = []
ao_curves = []

# Loop over AOs
for ao_id, group in results_vi_df.groupby("AO_id"):
    try:
        # Fit hyperbola
        popt, _ = curve_fit(
            generalized_hyperbola,
            group['r'],
            group['b_target_avg'],
            p0=[1.0, 1.0, 0.1],
            bounds=(0, np.inf),
            maxfev=10000
        )

        # Predict using the fit
        predicted = generalized_hyperbola(group['r'], *popt)

        # Compute R²
        residuals = group['b_target_avg'] - predicted
        ss_res = np.sum(residuals ** 2)
        ss_tot = np.sum((group['b_target_avg'] - np.mean(group['b_target_avg'])) ** 2)
        r_squared = 1 - (ss_res / ss_tot)

        # Store parameters and R²
        fit_results.append({
            'AO_id': ao_id,
            'a': popt[0],
            'b': popt[1],
            'c': popt[2],
            'r_squared': r_squared
        })

        # Store predicted curve
        r_vals = np.linspace(group['r'].min(), group['r'].max(), 200)
        predicted_curve = generalized_hyperbola(r_vals, *popt)
        ao_curves.append(pd.DataFrame({
            'AO_id': ao_id,
            'r': r_vals,
            'predicted_b_target': predicted_curve
        }))

    except:
        continue  # skip AOs that fail to fit

fit_df = pd.DataFrame(fit_results)
curves_df = pd.concat(ao_curves, ignore_index=True)

# Group-level mean and CI
grouped = results_vi_df.groupby('r').agg(
    b_target_avg_mean=('b_target_avg', 'mean'),
    b_target_avg_se=('b_target_avg', sem)
).reset_index()
grouped['b_target_avg_ci'] = grouped['b_target_avg_se'] * t.ppf(0.975, df=len(results_df) - 1)

# Fit group-level curve
popt, _ = curve_fit(
    generalized_hyperbola,
    grouped['r'],
    grouped['b_target_avg_mean'],
    p0=[1.0, 1.0, 0.1],
    bounds=(0, np.inf)
)
r_vals = np.linspace(grouped['r'].min(), grouped['r'].max(), 200)
group_fit = generalized_hyperbola(r_vals, *popt)

# R² calculation for group-level fit
residuals = grouped['b_target_avg_mean'] - generalized_hyperbola(grouped['r'], *popt)
ss_res = np.sum(residuals ** 2)
ss_tot = np.sum((grouped['b_target_avg_mean'] - np.mean(grouped['b_target_avg_mean'])) ** 2)
r_squared = 1 - (ss_res / ss_tot)

# Plot
plt.figure(figsize=(8, 6))

# Plot all individual AO fits
for ao_id, ao_curve in curves_df.groupby("AO_id"):
    plt.plot(ao_curve['r'], ao_curve['predicted_b_target'], alpha=0.1, color='gray')

# Plot group-level fit
plt.plot(r_vals, group_fit, color='black', linewidth=2, label='Group Mean Fit')

# Overlay mean ± CI as white points with black outlines
plt.errorbar(
    grouped['r'],
    grouped['b_target_avg_mean'],
    yerr=grouped['b_target_avg_ci'],
    fmt='o',
    color='white',
    markersize=5,
    elinewidth=1.5,
    ecolor='black',
    capsize=4,
    markeredgecolor='black',
    label='Mean ± 95% CI'
)

# Add annotation for group-level fit parameters
plt.text(
    x=1,
    y=0.004,
    s=f"$a$ = {popt[0]:.2f}\n$b$ = {popt[1]:.2f}\n$c$ = {popt[2]:.2f}\n$r^2$ = {r_squared:.4f}",
    fontsize=20,
    ha='right'
)

# Style
plt.yscale("log")
plt.xlabel("Reinforcement Rate", fontsize=30)
plt.ylabel("Target Response Rate", fontsize=30)
plt.gca().yaxis.set_major_formatter(ticker.FormatStrFormatter('%.4f'))
plt.legend().remove()
plt.ylim(.0001, 300)
plt.grid(False)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
sns.despine()
plt.tight_layout()
plt.show()
No description has been provided for this image
In [14]:
fig, ax = plt.subplots(figsize=(1, 5))
sns.stripplot(fit_df['r_squared'], color='k', alpha=0.2)
sns.boxplot(fit_df['r_squared'], color='w', showfliers=False)
sns.despine(top=True, right=True)
plt.ylabel(r"$r^2$", fontsize=14, labelpad=12)
plt.ylim(0.8, 1)
plt.show()
fit_df['r_squared'].describe()
No description has been provided for this image
Out[14]:
r_squared
count 400.000000
mean 0.980894
std 0.073000
min 0.346293
25% 0.996896
50% 0.999256
75% 0.999776
max 0.999988

In [15]:
for vi_focus in results_vi_df['VI'].unique():
  df_vi = results_vi_df[results_vi_df['VI'] == vi_focus].copy()

  # Metric → Label + CI column mapping
  metrics_vi = {
      'arousal_avg': ('Arousal', 'arousal_ci'),
      'coupling_avg': ('Coupling', 'coupling_ci'),
      'b_target_avg': ('Target Response Rate', 'b_target_ci'),
      'b_total_avg': ('Total Response Rate', 'b_total_ci')
  }

  params = ['activation', 'delta', 'alpha', 'eta']

  # Create 4x4 subplot grid
  fig, axs = plt.subplots(4, 4, figsize=(14, 14), sharex='col', sharey=False)
  fig.subplots_adjust(hspace=0.4, wspace=0.4)

  # Plot each metric × parameter
  for row_idx, (metric_col, (metric_label, ci_col)) in enumerate(metrics_vi.items()):
      for col_idx, param in enumerate(params):
          ax = axs[row_idx, col_idx]

          # Get data for this combo
          sub_df = df_vi[[param, metric_col, ci_col]].copy()

          # Convert the x (param) to categorical codes for spacing
          sub_df['x_val'] = pd.Categorical(sub_df[param]).codes
          x_jittered = sub_df['x_val'] + np.random.uniform(-0.2, 0.2, size=len(sub_df))

          # Plot strip manually
          ax.scatter(x_jittered, sub_df[metric_col], color='black', alpha=0.7, s=15, zorder=2)

          # Plot error bars (CIs)
          ax.errorbar(
              x_jittered,
              sub_df[metric_col],
              yerr=sub_df[ci_col],
              fmt='none',
              ecolor='gray',
              elinewidth=0.8,
              capsize=2,
              alpha=0.6,
              zorder=1
          )

          # Add symbolic labels
          param_symbols = {
              'activation': 'Activation',
              'delta': r'$\delta$',
              'alpha': r'$\alpha$',
              'eta': r'$\eta$'
          }

          # Format axes
          if row_idx == 3:
              ax.set_xlabel(param_symbols[param], fontsize=20, labelpad=8, color='k')
              ax.set_xticks(range(len(sub_df[param].unique())))
              ax.set_xticklabels(sorted(sub_df[param].unique()), fontsize=10, color='k')
          else:
              ax.set_xlabel("")
              ax.set_xticklabels([])
              ax.tick_params(labelbottom=False)

          if col_idx == 0:
              ax.set_ylabel(metric_label, fontsize=20, labelpad=8, color='k')
          else:
              ax.set_ylabel("")
              ax.set_yticklabels([])

          ax.grid(False)
          sns.despine(top=True, right=True)

  plt.tight_layout(rect=[0, 0, 1, 0.97])
  plt.suptitle(f"VI Schedule = {vi_focus}", fontsize=20)
  plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

A Range of Concurrent VRs¶

In [16]:
import numpy as np
import pandas as pd
from itertools import product
from tqdm import tqdm
from scipy.stats import sem, t

# --- MPR Model Equations ---
def update_arousal(prev_arousal: float, reinforcement: int, alpha: float, a: float) -> float:
    return alpha * (a * reinforcement) + (1 - alpha) * prev_arousal

def compute_total_response_rate(A: float, delta: float) -> float:
    return (1 / delta) * (A / (1 + A))

def compute_target_response_rate(A: float, delta: float, coupling: float) -> float:
    return coupling * (1 / delta) * (A / (1 + A))

def update_coupling(prev_coupling: float, was_reinforced: bool, learning_rate: float, proximity: float = 1.0) -> float:
    target = proximity if was_reinforced else 0.1
    return prev_coupling + learning_rate * (target - prev_coupling)

def update_step(prev_arousal, prev_coupling, reinforcement, a, delta, alpha, eta):
    A = update_arousal(prev_arousal, reinforcement, alpha, a)
    C = update_coupling(prev_coupling, was_reinforced=bool(reinforcement), learning_rate=eta)
    b_total = compute_total_response_rate(A, delta)
    b_target = compute_target_response_rate(A, delta, C)
    return A, C, b_total, b_target

# --- CI Helper ---
def ci95(data):
    n = len(data)
    return sem(data) * t.ppf(0.975, n - 1)

# --- Simulation with Persistent State Across Concurrent VR Pairs ---
def simulate_concurrent_mpr_sequence(vr_pairs, a, delta, alpha, eta, ao_id, num_steps=1000):
    A1, C1 = 0.6, 0.5
    A2, C2 = 0.6, 0.5
    all_results = []

    for pair_index, (vr1, vr2) in enumerate(vr_pairs):
        A1_hist, A2_hist = [], []
        C1_hist, C2_hist = [], []
        b1_hist, b2_hist = [], []

        for _ in range(num_steps):
            p1 = C1 / (C1 + C2)
            choice = np.random.choice([1, 2], p=[p1, 1 - p1])

            r1 = np.random.rand() < (1 / vr1)
            r2 = np.random.rand() < (1 / vr2)

            if choice == 1:
                A1, C1, _, _ = update_step(A1, C1, r1, a, delta, alpha, eta)
                A2, C2 = update_arousal(A2, 0, alpha, a), update_coupling(C2, False, eta)
            else:
                A2, C2, _, _ = update_step(A2, C2, r2, a, delta, alpha, eta)
                A1, C1 = update_arousal(A1, 0, alpha, a), update_coupling(C1, False, eta)

            A1_hist.append(A1)
            A2_hist.append(A2)
            C1_hist.append(C1)
            C2_hist.append(C2)
            b1_hist.append(1 if choice == 1 else 0)
            b2_hist.append(1 if choice == 2 else 0)

        all_results.append({
            'AO_id': ao_id,
            'pair_index': pair_index,
            'VR1': vr1,
            'VR2': vr2,
            'reinforcement_ratio': np.round((vr2 / vr1), 2),
            'arousal1_avg': np.mean(A1_hist),
            'arousal2_avg': np.mean(A2_hist),
            'coupling1_avg': np.mean(C1_hist),
            'coupling2_avg': np.mean(C2_hist),
            'choice1_avg': np.mean(b1_hist),
            'choice2_avg': np.mean(b2_hist),
            'arousal1_ci': ci95(A1_hist),
            'arousal2_ci': ci95(A2_hist),
            'coupling1_ci': ci95(C1_hist),
            'coupling2_ci': ci95(C2_hist),
            'choice1_ci': ci95(b1_hist),
            'choice2_ci': ci95(b2_hist),
            'activation': a,
            'delta': delta,
            'alpha': alpha,
            'eta': eta
        })

    return all_results

# Parameter Grid
deltas = [0.25, 0.50, 1.0, 2.0]
activations = [0.1, 0.2, 0.4, 0.8, 1.6]
alphas = [0.01, 0.03, 0.1, 0.3, 1.0]
etas = [0.01, 0.03, 0.3, 1.0]
vr_pairs = [(1, 9), (1, 3), (2, 2), (3, 1), (9, 1)]  # all pairs to be run per AO

# Run each AO through all VR pairs
concurrent_results = []
grid = list(product(activations, deltas, alphas, etas))

for ao_id, (a, delta, alpha, eta) in enumerate(tqdm(grid, desc="Running concurrent VR AOs")):
    sim_result_list = simulate_concurrent_mpr_sequence(vr_pairs, a, delta, alpha, eta, ao_id=ao_id)
    concurrent_results.extend(sim_result_list)

# Final DataFrame
results_concurrent_df = pd.DataFrame(concurrent_results)
Running concurrent VR AOs: 100%|██████████| 400/400 [01:06<00:00,  5.99it/s]
In [17]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from scipy.stats import sem, t

# Step 1: Prep data
results_concurrent_df['response_ratio'] = results_concurrent_df['choice1_avg'] / results_concurrent_df['choice2_avg']
results_concurrent_df['log_resp_ratio'] = np.log(results_concurrent_df['response_ratio'])
results_concurrent_df['log_reinf_ratio'] = np.log(results_concurrent_df['reinforcement_ratio'])

# Step 2: Fit matching law to each AO
ao_fits = []
ao_lines = []

for ao_id, group in results_concurrent_df.groupby("AO_id"):
    if group['log_reinf_ratio'].nunique() < 2:
        continue
    X = group[['log_reinf_ratio']]
    y = group['log_resp_ratio']
    model = LinearRegression().fit(X, y)
    slope = model.coef_[0]
    intercept = model.intercept_
    r2 = r2_score(y, model.predict(X))

    ao_fits.append({
        'AO_id': ao_id,
        'slope': slope,
        'intercept': intercept,
        'r_squared': r2
    })

    # Store line for plotting
    x_range = np.linspace(X.min(), X.max(), 100)
    ao_lines.append(pd.DataFrame({
        'AO_id': ao_id,
        'x': x_range.squeeze(),
        'y': model.predict(pd.DataFrame({'log_reinf_ratio': x_range.squeeze()}))
    }))

fit_df = pd.DataFrame(ao_fits)
ao_lines_df = pd.concat(ao_lines)

# Step 3: Aggregate means across AOs per VR pair
grouped_means = results_concurrent_df.groupby(['VR1', 'VR2', 'reinforcement_ratio']).agg(
    log_resp_ratio_mean=('log_resp_ratio', 'mean'),
    log_resp_ratio_se=('log_resp_ratio', sem),
    log_reinf_ratio=('log_reinf_ratio', 'mean')
).reset_index()

grouped_means['log_resp_ratio_ci'] = grouped_means['log_resp_ratio_se'] * t.ppf(0.975, df=len(results_concurrent_df) - 1)

# Step 4: Fit matching law to group means
X_group = grouped_means[['log_reinf_ratio']]
y_group = grouped_means['log_resp_ratio_mean']

group_model = LinearRegression().fit(X_group, y_group)
group_slope = group_model.coef_[0]
group_intercept = group_model.intercept_
group_r2 = r2_score(y_group, group_model.predict(X_group))

x_range_group = np.linspace(X_group.min(), X_group.max(), 100)
y_pred_group = group_model.predict(pd.DataFrame({'log_reinf_ratio': x_range_group.squeeze()}))

# Step 5: Plot everything
plt.figure(figsize=(7, 7))

# Gray points and individual AO fits
sns.scatterplot(
    x='log_reinf_ratio',
    y='log_resp_ratio',
    data=results_concurrent_df,
    color='gray',
    alpha=0.3,
    label='Individual AO'
)

for _, ao_line in ao_lines_df.groupby('AO_id'):
    plt.plot(ao_line['x'], ao_line['y'], color='gray', alpha=0.1)

# Mean ± 95% CI
plt.errorbar(
    x=grouped_means['log_reinf_ratio'],
    y=grouped_means['log_resp_ratio_mean'],
    yerr=grouped_means['log_resp_ratio_ci'],
    fmt='o',
    color='white',
    markersize=5,
    elinewidth=1.5,
    ecolor='black',
    capsize=4,
    markeredgecolor='black',
    label='Mean ± 95% CI'
)

# Group fit
plt.plot(x_range_group, y_pred_group, color='black', linewidth=2, label='Group Mean Fit')

# Identity line
plt.plot([-3, 3], [-3, 3], linestyle='--', color='gray')

# Labels and annotation
plt.xlabel(r'$\mathit{log}\left(\frac{R_1}{R_2}\right)$', fontsize=24, labelpad=12)
plt.ylabel(r'$\mathit{log}\left(\frac{B_1}{B_2}\right)$', fontsize=24, labelpad=12)

plt.text(
    1.5, -2.5,
    f"Sensitivity = {group_slope:.2f}\nBias = {group_intercept:.2f}\n$r^2$ = {group_r2:.2f}",
    fontsize=20,
    ha='center'
)

plt.legend(frameon=False, fontsize=16)
sns.despine(top=True, right=True)
plt.grid(False)
plt.tight_layout()
plt.show()
No description has been provided for this image
In [18]:
fig, ax = plt.subplots(figsize=(1, 5))
sns.stripplot(fit_df['r_squared'], color='k', alpha=0.2)
sns.boxplot(fit_df['r_squared'], color='w', showfliers=False)
sns.despine(top=True, right=True)
plt.ylabel(r"$r^2$", fontsize=14, labelpad=12)
plt.ylim(0.8, 1)
plt.show()

fit_df['r_squared'].describe()
No description has been provided for this image
Out[18]:
r_squared
count 400.000000
mean 0.927126
std 0.026913
min 0.822572
25% 0.909812
50% 0.930033
75% 0.947136
max 0.989533

In [19]:
from scipy.optimize import least_squares
from sklearn.metrics import r2_score
from scipy.stats import sem, t
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Define McDowell concurrent schedule equations
def B1_eq(r1, r2, a, b1, c1, c2):
    return (b1 * (r1 ** a)) / (r1 ** a + (c1 / c2) * (r2 ** a) + c1)

def B2_eq(r1, r2, a, b2, c1, c2):
    return (b2 * (r2 ** a)) / ((c2 / c1) * (r1 ** a) + (r2 ** a) + c2)

# Joint residual function for least squares
def joint_residuals(params, r1, r2, B1_obs, B2_obs):
    a, b1, b2, c1, c2 = params
    B1_pred = B1_eq(r1, r2, a, b1, c1, c2)
    B2_pred = B2_eq(r1, r2, a, b2, c1, c2)
    return np.concatenate([(B1_pred - B1_obs), (B2_pred - B2_obs)])

# Prepare data
results_df = results_concurrent_df.copy()
results_df['r1'] = 1 / results_df['VR1']
results_df['r2'] = 1 / results_df['VR2']

# Fit per AO
fit_results = []
ao_fits = []

for ao_id, group in results_df.groupby("AO_id"):
    if group.shape[0] < 5:
        continue
    try:
        x0 = [1.0, 1.0, 1.0, 0.1, 0.1]
        res = least_squares(
            joint_residuals,
            x0=x0,
            bounds=(0, np.inf),
            args=(group['r1'].values, group['r2'].values, group['choice1_avg'].values, group['choice2_avg'].values)
        )
        a, b1, b2, c1, c2 = res.x

        # Compute predictions and R²
        B1_pred = B1_eq(group['r1'], group['r2'], a, b1, c1, c2)
        B2_pred = B2_eq(group['r1'], group['r2'], a, b2, c1, c2)
        r2_B1 = r2_score(group['choice1_avg'], B1_pred)
        r2_B2 = r2_score(group['choice2_avg'], B2_pred)

        fit_results.append({
            'AO_id': ao_id,
            'a': a, 'b1': b1, 'b2': b2, 'c1': c1, 'c2': c2,
            'r2_B1': r2_B1, 'r2_B2': r2_B2
        })

        ao_fits.append({
            'AO_id': ao_id,
            'r1': group['r1'].values,
            'r2': group['r2'].values,
            'B1_pred': B1_pred,
            'B2_pred': B2_pred
        })
    except:
        continue

fit_df = pd.DataFrame(fit_results)

# Group means and CI
grouped = results_df.groupby(['VR1', 'VR2']).agg({
    'r1': 'mean',
    'r2': 'mean',
    'choice1_avg': ['mean', sem],
    'choice2_avg': ['mean', sem]
}).reset_index()

grouped.columns = ['VR1', 'VR2', 'r1', 'r2', 'B1_mean', 'B1_se', 'B2_mean', 'B2_se']
grouped['B1_ci'] = grouped['B1_se'] * t.ppf(0.975, df=len(results_df) - 1)
grouped['B2_ci'] = grouped['B2_se'] * t.ppf(0.975, df=len(results_df) - 1)

# Fit model to group means
x0 = [1.0, 1.0, 1.0, 0.1, 0.1]
res = least_squares(
    joint_residuals,
    x0=x0,
    bounds=(0, np.inf),
    args=(grouped['r1'].values, grouped['r2'].values, grouped['B1_mean'].values, grouped['B2_mean'].values)
)
a, b1, b2, c1, c2 = res.x

grouped['B1_pred'] = B1_eq(grouped['r1'], grouped['r2'], a, b1, c1, c2)
grouped['B2_pred'] = B2_eq(grouped['r1'], grouped['r2'], a, b2, c1, c2)

r2_B1 = r2_score(grouped['B1_mean'], grouped['B1_pred'])
r2_B2 = r2_score(grouped['B2_mean'], grouped['B2_pred'])

# Plot
plt.figure(figsize=(7, 6))

# All AO lines in gray
for fit in ao_fits:
    plt.plot(fit['r1'], fit['B1_pred'], color='gray', alpha=0.1, zorder=1)
    plt.plot(fit['r2'], fit['B2_pred'], color='gray', alpha=0.1, linestyle='--', zorder=1)

# Mean data points with CI
sns.scatterplot(x=grouped['r1'], y=grouped['B1_mean'], color='black', edgecolor='black', s=60,
                label='B1 mean ± 95% CI', zorder=10)
plt.errorbar(grouped['r1'], grouped['B1_mean'], yerr=grouped['B1_ci'], fmt='none',
             ecolor='black', capsize=4, zorder=10)

sns.scatterplot(x=grouped['r2'], y=grouped['B2_mean'], color='white', edgecolor='black', s=60,
                label='B2 mean ± 95% CI', zorder=10)
plt.errorbar(grouped['r2'], grouped['B2_mean'], yerr=grouped['B2_ci'], fmt='none',
             ecolor='gray', capsize=4, zorder=10)

# Group fits in bold
plt.plot(grouped['r1'], grouped['B1_pred'], color='black', linewidth=2,
         label='B1 fit', zorder=9)
plt.plot(grouped['r2'], grouped['B2_pred'], color='gray', linewidth=2, linestyle='--',
         label='B2 fit', zorder=9)

# Annotation
plt.text(
    1, 0.05,
    f"$a$ = {a:.2f}\n"
    f"$b_1$ = {b1:.2f}, $b_2$ = {b2:.2f}\n"
    f"$c_1$ = {c1:.2f}, $c_2$ = {c2:.2f}\n"
    f"$r^2_{{B1}}$ = {r2_B1:.2f}\n"
    f"$r^2_{{B2}}$ = {r2_B2:.2f}",
    fontsize=18,
    ha='right'
)

plt.xlabel('Reinforcement Rate', fontsize=26, labelpad=12)
plt.ylabel('Response Rate', fontsize=26, labelpad=12)
plt.yticks(fontsize=16)
plt.xticks(fontsize=16)
plt.legend(frameon=False, fontsize=12)
sns.despine(top=True, right=True)
plt.tight_layout()
plt.title('Concurrent VR Schedule Fit to\nGeneralized Hyperbolic Equations\nPer AO + Group-Level', fontsize=14)
plt.show()
No description has been provided for this image
In [20]:
# Prepare data for concurrent schedule version of the same 4x4 grid plot
params = ['activation', 'delta', 'alpha', 'eta']

# Mapping for each lever's metrics
metrics_concurrent_1 = {
    'arousal1_avg': (r'Arousal ($\it{B}_{1}$)', 'arousal1_ci'),
    'coupling1_avg': (r'Coupling ($\it{B}_{1}$)', 'coupling1_ci'),
    'choice1_avg': (r'Choice Rate ($\it{B}_{1}$)', 'choice1_ci'),
}

metrics_concurrent_2 = {
    'arousal2_avg': (r'Arousal ($\it{B}_{2}$)', 'arousal2_ci'),
    'coupling2_avg': (r'Coupling ($\it{B}_{2}$)', 'coupling2_ci'),
    'choice2_avg': (r'Choice Rate ($\it{B}_{2}$)', 'choice2_ci'),
}
In [21]:
# Create subplot function
def plot_concurrent_grid(df, metrics, title):
    fig, axs = plt.subplots(3, 4, figsize=(16, 14), sharex='col', sharey=False)
    fig.subplots_adjust(hspace=0.4, wspace=0.4)

    for row_idx, (metric_col, (metric_label, ci_col)) in enumerate(metrics.items()):
        for col_idx, param in enumerate(params):
            ax = axs[row_idx, col_idx]
            if metric_col == '':  # Skip empty row
                ax.axis('off')
                continue

            sub_df = df[[param, metric_col, ci_col]].copy()
            sub_df['x_val'] = pd.Categorical(sub_df[param]).codes
            x_jittered = sub_df['x_val'] + np.random.uniform(-0.2, 0.2, size=len(sub_df))

            ax.scatter(x_jittered, sub_df[metric_col], color='black', alpha=0.7, s=15, zorder=2)
            ax.errorbar(
                x_jittered,
                sub_df[metric_col],
                yerr=sub_df[ci_col],
                fmt='none',
                ecolor='gray',
                elinewidth=0.8,
                capsize=2,
                alpha=0.6,
                zorder=1
            )

            # Add symbolic labels
            param_symbols = {
                'activation': 'Activation',
                'delta': r'$\delta$',
                'alpha': r'$\alpha$',
                'eta': r'$\eta$'
            }

            # Format axes
            if row_idx == 2:
                ax.set_xlabel(param_symbols[param], fontsize=20, labelpad=8, color='k')
                ax.set_xticks(range(len(sub_df[param].unique())))
                ax.set_xticklabels(sorted(sub_df[param].unique()), fontsize=10, color='k')
            else:
                ax.set_xlabel("")
                ax.set_xticklabels([])
                ax.tick_params(labelbottom=False)

            if col_idx == 0:
                ax.set_ylabel(metric_label, fontsize=20, labelpad=8, color='k')
            else:
                ax.set_ylabel("")
                ax.set_yticklabels([])

            ax.grid(False)
            sns.despine(top=True, right=True)

    plt.suptitle(title, fontsize=30)
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()
In [22]:
# Plot Lever 1
plot_concurrent_grid(results_concurrent_df, metrics_concurrent_1, r"Concurrent Schedule: $\it{B}_{1}$")
No description has been provided for this image
In [23]:
# Plot Lever 2
plot_concurrent_grid(results_concurrent_df, metrics_concurrent_2, r"Concurrent Schedule: $\it{B}_{2}$")
No description has been provided for this image

A Range of Concurrent VIs¶

In [24]:
import numpy as np
import pandas as pd
from itertools import product
from scipy.stats import sem, t
from tqdm import tqdm

# --- MPR Model Equations ---
def update_arousal(prev_arousal, reinforcement, alpha, a):
    return alpha * (a * reinforcement) + (1 - alpha) * prev_arousal

def compute_total_response_rate(A, delta):
    return (1 / delta) * (A / (1 + A))

def compute_target_response_rate(A, delta, coupling):
    return coupling * (1 / delta) * (A / (1 + A))

def update_coupling(prev_coupling, was_reinforced, learning_rate, proximity=1.0):
    target = proximity if was_reinforced else 0.1
    return prev_coupling + learning_rate * (target - prev_coupling)

def update_step(prev_arousal, prev_coupling, reinforcement, a, delta, alpha, eta):
    A = update_arousal(prev_arousal, reinforcement, alpha, a)
    C = update_coupling(prev_coupling, bool(reinforcement), eta)
    b_total = compute_total_response_rate(A, delta)
    b_target = compute_target_response_rate(A, delta, C)
    return A, C, b_total, b_target

# --- Concurrent VI Sequence Simulation (AO across all VI pairs) ---
def simulate_concurrent_mpr_vi_sequence(vi_pairs, a, delta, alpha, eta, ao_id, num_steps=1000, avg_peck_rate=2.5):
    A1, C1 = 0.6, 0.5
    A2, C2 = 0.6, 0.5
    results = []

    for pair_index, (vi1, vi2) in enumerate(vi_pairs):
        A1_hist, A2_hist, C1_hist, C2_hist = [], [], [], []
        b1_hist, b2_hist = [], []

        avg_irt = 1.0 / avg_peck_rate
        vi_times1 = np.cumsum(np.random.exponential(scale=vi1, size=num_steps))
        vi_times2 = np.cumsum(np.random.exponential(scale=vi2, size=num_steps))
        t1_index, t2_index = 0, 0
        next_r1, next_r2 = vi_times1[0], vi_times2[0]
        current_time = 0

        for _ in range(num_steps):
            current_time += avg_irt
            r1 = int(current_time >= next_r1)
            r2 = int(current_time >= next_r2)

            if r1 and t1_index + 1 < len(vi_times1):
                t1_index += 1
                next_r1 = vi_times1[t1_index]
            if r2 and t2_index + 1 < len(vi_times2):
                t2_index += 1
                next_r2 = vi_times2[t2_index]

            # Choice
            p1 = C1 / (C1 + C2)
            choice = np.random.choice([1, 2], p=[p1, 1 - p1])

            if choice == 1:
                A1, C1, _, _ = update_step(A1, C1, r1, a, delta, alpha, eta)
                A2, C2 = update_arousal(A2, 0, alpha, a), update_coupling(C2, False, eta)
                b1_hist.append(1)
                b2_hist.append(0)
            else:
                A2, C2, _, _ = update_step(A2, C2, r2, a, delta, alpha, eta)
                A1, C1 = update_arousal(A1, 0, alpha, a), update_coupling(C1, False, eta)
                b1_hist.append(0)
                b2_hist.append(1)

            A1_hist.append(A1)
            A2_hist.append(A2)
            C1_hist.append(C1)
            C2_hist.append(C2)

        def ci95(data):
            return sem(data) * t.ppf(0.975, len(data) - 1)

        results.append({
            'AO_id': ao_id,
            'pair_index': pair_index,
            'VI1': vi1,
            'VI2': vi2,
            'reinforcement_ratio': round(vi2 / vi1, 2),
            'arousal1_avg': np.mean(A1_hist),
            'arousal2_avg': np.mean(A2_hist),
            'coupling1_avg': np.mean(C1_hist),
            'coupling2_avg': np.mean(C2_hist),
            'choice1_avg': np.mean(b1_hist),
            'choice2_avg': np.mean(b2_hist),
            'arousal1_ci': ci95(A1_hist),
            'arousal2_ci': ci95(A2_hist),
            'coupling1_ci': ci95(C1_hist),
            'coupling2_ci': ci95(C2_hist),
            'choice1_ci': ci95(b1_hist),
            'choice2_ci': ci95(b2_hist),
            'activation': a,
            'delta': delta,
            'alpha': alpha,
            'eta': eta
        })

    return results

# Params
deltas = [0.25, 0.50, 1.0, 2.0]
activations = [0.1, 0.2, 0.4, 0.8, 1.6]
alphas = [0.01, 0.03, 0.1, 0.3, 1.0]
etas = [0.01, 0.03, 0.3, 1.0]
vi_pairs = [(1, 9), (1, 3), (2, 2), (3, 1), (9, 1)]

# Simulation
results_concurrent_vi = []
grid = list(product(activations, deltas, alphas, etas))

for ao_id, (a, delta, alpha, eta) in enumerate(tqdm(grid, desc="Concurrent VI AO sims")):
    shuffled_pairs = np.random.permutation(vi_pairs)
    result = simulate_concurrent_mpr_vi_sequence(shuffled_pairs, a, delta, alpha, eta, ao_id=ao_id)
    results_concurrent_vi.extend(result)

# Final DataFrame
results_concurrent_vi_df = pd.DataFrame(results_concurrent_vi)
Concurrent VI AO sims: 100%|██████████| 400/400 [01:02<00:00,  6.41it/s]
In [25]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from scipy.stats import sem, t

# Step 1: Calculate log ratios
results_concurrent_vi_df['response_ratio'] = results_concurrent_vi_df['choice1_avg'] / results_concurrent_vi_df['choice2_avg']
results_concurrent_vi_df['log_resp_ratio'] = np.log(results_concurrent_vi_df['response_ratio'])
results_concurrent_vi_df['log_reinf_ratio'] = np.log(results_concurrent_vi_df['reinforcement_ratio'])

# Step 2: Fit matching law per AO
ao_fits = []
ao_lines = []

for ao_id, group in results_concurrent_vi_df.groupby("AO_id"):
    if group['log_reinf_ratio'].nunique() < 2:
        continue
    X = group[['log_reinf_ratio']]
    y = group['log_resp_ratio']
    model = LinearRegression().fit(X, y)
    slope = model.coef_[0]
    intercept = model.intercept_
    r2 = r2_score(y, model.predict(X))

    ao_fits.append({'AO_id': ao_id, 'slope': slope, 'intercept': intercept, 'r_squared': r2})

    x_range = np.linspace(X.min(), X.max(), 100)
    ao_lines.append(pd.DataFrame({
        'AO_id': ao_id,
        'x': x_range.squeeze(),
        'y': model.predict(pd.DataFrame({'log_reinf_ratio': x_range.squeeze()}))
    }))

fit_df = pd.DataFrame(ao_fits)
ao_lines_df = pd.concat(ao_lines)

# Step 3: Group-level mean + CI
grouped_means = results_concurrent_vi_df.groupby(['VI1', 'VI2', 'reinforcement_ratio']).agg(
    log_resp_ratio_mean=('log_resp_ratio', 'mean'),
    log_resp_ratio_se=('log_resp_ratio', sem),
    log_reinf_ratio=('log_reinf_ratio', 'mean')
).reset_index()

grouped_means['log_resp_ratio_ci'] = grouped_means['log_resp_ratio_se'] * t.ppf(0.975, df=len(results_concurrent_vi_df) - 1)

# Step 4: Fit group-level GML
X_group = grouped_means[['log_reinf_ratio']]
y_group = grouped_means['log_resp_ratio_mean']
model_group = LinearRegression().fit(X_group, y_group)
slope = model_group.coef_[0]
intercept = model_group.intercept_
r2 = r2_score(y_group, model_group.predict(X_group))

x_range_group = np.linspace(X_group.min(), X_group.max(), 100)
y_pred_group = model_group.predict(pd.DataFrame({'log_reinf_ratio': x_range_group.squeeze()}))

# Step 5: Plot
plt.figure(figsize=(7, 7))

# Individual AO data and fits
sns.scatterplot(
    x='log_reinf_ratio',
    y='log_resp_ratio',
    data=results_concurrent_vi_df,
    color='gray',
    alpha=0.3,
    label='Individual AO',
    zorder=1
)

for _, ao_line in ao_lines_df.groupby('AO_id'):
    plt.plot(ao_line['x'], ao_line['y'], color='gray', alpha=0.1, zorder=1)

# Group means with CI
plt.errorbar(
    x=grouped_means['log_reinf_ratio'],
    y=grouped_means['log_resp_ratio_mean'],
    yerr=grouped_means['log_resp_ratio_ci'],
    fmt='o',
    color='white',
    markersize=6,
    elinewidth=1.5,
    ecolor='black',
    capsize=4,
    markeredgecolor='black',
    label='Mean ± 95% CI',
    zorder=10
)

# Group-level GML line
plt.plot(x_range_group, y_pred_group, color='black', linewidth=2, label='Mean GML Fit', zorder=9)
plt.plot([-3, 3], [-3, 3], linestyle='--', color='gray', zorder=1)

# Labels and annotation
plt.xlabel(r'$\mathit{log}\left(\frac{R_1}{R_2}\right)$', fontsize=22, labelpad=12)
plt.ylabel(r'$\mathit{log}\left(\frac{B_1}{B_2}\right)$', fontsize=22, labelpad=12)

plt.text(
    1.5, -2.5,
    f"Sensitivity = {slope:.2f}\nBias = {intercept:.2f}\n$r^2 = {r2:.2f}$",
    fontsize=20,
    ha='center'
)

plt.grid(False)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(frameon=False, fontsize=16)
sns.despine(top=True, right=True)
plt.tight_layout()
plt.show()
No description has been provided for this image
In [26]:
fig, ax = plt.subplots(figsize=(1, 5))
sns.stripplot(fit_df['r_squared'], color='k', alpha=0.2)
sns.boxplot(fit_df['r_squared'], color='w', showfliers=False)
sns.despine(top=True, right=True)
plt.ylabel(r"$r^2$", fontsize=14, labelpad=12)
plt.ylim(0.8, 1)
plt.show()

fit_df['r_squared'].describe()
No description has been provided for this image
Out[26]:
r_squared
count 400.000000
mean 0.929968
std 0.047380
min 0.692368
25% 0.904959
50% 0.941540
75% 0.962126
max 0.999509

In [27]:
from scipy.optimize import least_squares
from sklearn.metrics import r2_score
from scipy.stats import sem, t
import matplotlib.pyplot as plt
import seaborn as sns

# --- McDowell equations ---
def B1_eq(r1, r2, a, b1, c1, c2):
    return (b1 * (r1 ** a)) / (r1 ** a + (c1 / c2) * (r2 ** a) + c1)

def B2_eq(r1, r2, a, b2, c1, c2):
    return (b2 * (r2 ** a)) / ((c2 / c1) * (r1 ** a) + (r2 ** a) + c2)

def joint_residuals(params, r1, r2, B1_obs, B2_obs):
    a, b1, b2, c1, c2 = params
    B1_pred = B1_eq(r1, r2, a, b1, c1, c2)
    B2_pred = B2_eq(r1, r2, a, b2, c1, c2)
    return np.concatenate([(B1_pred - B1_obs), (B2_pred - B2_obs)])

results_df = results_concurrent_vi_df.copy()
results_df['r1'] = 1 / results_df['VI1']
results_df['r2'] = 1 / results_df['VI2']

fit_results = []
ao_fits = []

for ao_id, group in results_df.groupby("AO_id"):
    if group.shape[0] < 5:
        continue
    try:
        x0 = [1.0, 1.0, 1.0, 0.1, 0.1]
        res = least_squares(
            joint_residuals,
            x0=x0,
            bounds=(0, np.inf),
            args=(group['r1'].values, group['r2'].values,
                  group['choice1_avg'].values, group['choice2_avg'].values)
        )
        a, b1, b2, c1, c2 = res.x

        B1_pred = B1_eq(group['r1'], group['r2'], a, b1, c1, c2)
        B2_pred = B2_eq(group['r1'], group['r2'], a, b2, c1, c2)
        r2_B1 = r2_score(group['choice1_avg'], B1_pred)
        r2_B2 = r2_score(group['choice2_avg'], B2_pred)

        fit_results.append({
            'AO_id': ao_id,
            'a': a, 'b1': b1, 'b2': b2,
            'c1': c1, 'c2': c2,
            'r2_B1': r2_B1,
            'r2_B2': r2_B2
        })

        ao_fits.append(pd.DataFrame({
            'AO_id': ao_id,
            'r1': group['r1'].values,
            'r2': group['r2'].values,
            'B1_pred': B1_pred,
            'B2_pred': B2_pred
        }))
    except:
        continue

fit_df = pd.DataFrame(fit_results)
ao_fits_df = pd.concat(ao_fits)

grouped = results_df.groupby(['VI1', 'VI2']).agg({
    'r1': 'mean',
    'r2': 'mean',
    'choice1_avg': ['mean', sem],
    'choice2_avg': ['mean', sem]
}).reset_index()

grouped.columns = ['VI1', 'VI2', 'r1', 'r2', 'B1_mean', 'B1_se', 'B2_mean', 'B2_se']
grouped['B1_ci'] = grouped['B1_se'] * t.ppf(0.975, df=len(results_df)-1)
grouped['B2_ci'] = grouped['B2_se'] * t.ppf(0.975, df=len(results_df)-1)

# Fit to group means
x0 = [1.0, 1.0, 1.0, 0.1, 0.1]
res = least_squares(
    joint_residuals,
    x0=x0,
    bounds=(0, np.inf),
    args=(grouped['r1'].values, grouped['r2'].values,
          grouped['B1_mean'].values, grouped['B2_mean'].values)
)
a, b1, b2, c1, c2 = res.x

grouped['B1_pred'] = B1_eq(grouped['r1'], grouped['r2'], a, b1, c1, c2)
grouped['B2_pred'] = B2_eq(grouped['r1'], grouped['r2'], a, b2, c1, c2)

r2_B1 = r2_score(grouped['B1_mean'], grouped['B1_pred'])
r2_B2 = r2_score(grouped['B2_mean'], grouped['B2_pred'])

plt.figure(figsize=(7, 6))

# Individual AO fits (gray)
for ao_id, ao_data in ao_fits_df.groupby('AO_id'):
    plt.plot(ao_data['r1'], ao_data['B1_pred'], color='gray', alpha=0.1, zorder=1)
    plt.plot(ao_data['r2'], ao_data['B2_pred'], color='gray', alpha=0.1, linestyle='--', zorder=1)

# Group means
sns.scatterplot(x=grouped['r1'], y=grouped['B1_mean'],
                color='black', edgecolor='black', s=60,
                label='B1 mean ± 95% CI', zorder=10)
plt.errorbar(grouped['r1'], grouped['B1_mean'], yerr=grouped['B1_ci'],
             fmt='none', ecolor='black', capsize=4, zorder=10)

sns.scatterplot(x=grouped['r2'], y=grouped['B2_mean'],
                color='white', edgecolor='black', s=60,
                label='B2 mean ± 95% CI', zorder=10)
plt.errorbar(grouped['r2'], grouped['B2_mean'], yerr=grouped['B2_ci'],
             fmt='none', ecolor='gray', capsize=4, zorder=10)

# Group fit lines
plt.plot(grouped['r1'], grouped['B1_pred'], color='black', linewidth=2, label='B1 fit', zorder=9)
plt.plot(grouped['r2'], grouped['B2_pred'], color='gray', linewidth=2, linestyle='--', label='B2 fit', zorder=9)

# Annotation
plt.text(
    1, 0.2,
    f"$a$ = {a:.2f}\n"
    f"$b_1$ = {b1:.2f}, $b_2$ = {b2:.2f}\n"
    f"$c_1$ = {c1:.2f}, $c_2$ = {c2:.2f}\n"
    f"$r^2_{{B1}}$ = {r2_B1:.2f}\n"
    f"$r^2_{{B2}}$ = {r2_B2:.2f}",
    fontsize=14,
    ha='right'
)

plt.xlabel('Reinforcement Rate', fontsize=20)
plt.ylabel('Response Rate', fontsize=20)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(frameon=False, fontsize=12)
sns.despine(top=True, right=True)
plt.tight_layout()
plt.title('Concurrent VI Schedule Fit to\nGeneralized Hyperbolic Equations', fontsize=14)
plt.show()
No description has been provided for this image
In [28]:
# Prepare data for concurrent schedule version of the same 4x4 grid plot
params = ['activation', 'delta', 'alpha', 'eta']

# Mapping for each lever's metrics
metrics_concurrent_1 = {
    'arousal1_avg': (r'Arousal ($\it{B}_{1}$)', 'arousal1_ci'),
    'coupling1_avg': (r'Coupling ($\it{B}_{1}$)', 'coupling1_ci'),
    'choice1_avg': (r'Choice Rate ($\it{B}_{1}$)', 'choice1_ci'),
}

metrics_concurrent_2 = {
    'arousal2_avg': (r'Arousal ($\it{B}_{2}$)', 'arousal2_ci'),
    'coupling2_avg': (r'Coupling ($\it{B}_{2}$)', 'coupling2_ci'),
    'choice2_avg': (r'Choice Rate ($\it{B}_{2}$)', 'choice2_ci'),
}
In [29]:
# Create subplot function
def plot_concurrent_grid(df, metrics, title):
    fig, axs = plt.subplots(3, 4, figsize=(16, 14), sharex='col', sharey=False)
    fig.subplots_adjust(hspace=0.4, wspace=0.4)

    for row_idx, (metric_col, (metric_label, ci_col)) in enumerate(metrics.items()):
        for col_idx, param in enumerate(params):
            ax = axs[row_idx, col_idx]
            if metric_col == '':  # Skip empty row
                ax.axis('off')
                continue

            sub_df = df[[param, metric_col, ci_col]].copy()
            sub_df['x_val'] = pd.Categorical(sub_df[param]).codes
            x_jittered = sub_df['x_val'] + np.random.uniform(-0.2, 0.2, size=len(sub_df))

            ax.scatter(x_jittered, sub_df[metric_col], color='black', alpha=0.7, s=15, zorder=2)
            ax.errorbar(
                x_jittered,
                sub_df[metric_col],
                yerr=sub_df[ci_col],
                fmt='none',
                ecolor='gray',
                elinewidth=0.8,
                capsize=2,
                alpha=0.6,
                zorder=1
            )

            # Add symbolic labels
            param_symbols = {
                'activation': 'Activation',
                'delta': r'$\delta$',
                'alpha': r'$\alpha$',
                'eta': r'$\eta$'
            }

            if row_idx == 2:
                ax.set_xlabel(param_symbols[param], fontsize=20, labelpad=8, color='k')
                ax.set_xticks(range(len(sub_df[param].unique())))
                ax.set_xticklabels(sorted(sub_df[param].unique()), fontsize=10, color='k')
            else:
                ax.set_xlabel("")
                ax.set_xticklabels([])
                ax.tick_params(labelbottom=False)

            if col_idx == 0:
                ax.set_ylabel(metric_label, fontsize=20, labelpad=8, color='k')
            else:
                ax.set_ylabel("")
                ax.set_yticklabels([])

            ax.grid(False)
            sns.despine(top=True, right=True)

    plt.suptitle(title, fontsize=30)
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()
In [30]:
# Plot Lever 1
plot_concurrent_grid(results_concurrent_vi_df, metrics_concurrent_1, r"Concurrent Schedule: $\it{B}_{1}$")
No description has been provided for this image
In [31]:
# Plot Lever 2
plot_concurrent_grid(results_concurrent_vi_df, metrics_concurrent_2, r"Concurrent Schedule: $\it{B}_{2}$")
No description has been provided for this image